from bluepyopt.ephys.responses import TimeVoltageResponse
import json
import numpy as np
from bluepyparallel import init_parallel_factory
import yaml
from functools import partial

from pathlib import Path
import matplotlib.pyplot as plt
import pickle

# from emodel_generalisation.utils import plot_traces
from emodel_generalisation.model.modifiers import synth_soma, synth_axon
from emodel_generalisation.mcmc import load_chains
from emodel_generalisation.mcmc import save_selected_emodels
import logging
from itertools import cycle
import json
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
from emodel_generalisation.model.access_point import AccessPoint
from emodel_generalisation.model.evaluation import feature_evaluation
from emodel_generalisation.utils import get_combo_hash
from emodel_generalisation.utils import get_feature_df
from datareuse import Reuse
import extra_features

if __name__ == "__main__":
    # Scan the passive conductance g_pas.all of one emodel around a reference
    # value, evaluate the features for every scanned value, and plot, for each
    # value, burst number against holding voltage (voltage_base) across the
    # selected protocols.
    parallel_factory = init_parallel_factory("multiprocessing")
    name = "runaway"
    access_point = AccessPoint(
        emodel_dir=".",
        recipes_path="config/recipes.json",
        final_path=f"final_{name}.json",
        with_seeds=True,
    )
    # Context manager so the YAML file handle is closed promptly.
    with open("../../mcmc_run/exemplar_data.yaml") as f:
        exemplar_data = yaml.safe_load(f)
    access_point.morph_path = exemplar_data["paths"]["all"]
    access_point.settings["morph_modifiers"] = [
        partial(synth_soma, params=exemplar_data["soma"], scale=1.0),
        partial(synth_axon, params=exemplar_data["ais"]["popt"], scale=1.0),
    ]

    emodel = "simplest_0"
    # NOTE(review): presumably the g_pas.all value of `emodel` stored in
    # final_{name}.json -- confirm it stays in sync with that file.
    gpas_ref = 5.303246551498367e-05
    scales = np.array([0.5, 0.75, 1.0, 1.25, 1.5])
    values = scales * gpas_ref
    # Position of the scale closest to 1.0 (the reference value); that curve
    # gets a distinct colour below. Derived rather than hard-coded so it stays
    # correct if `scales` is edited.
    ref_index = int(np.argmin(np.abs(scales - 1.0)))

    # One row per scanned value; "new_parameters" is a JSON dict of parameter
    # overrides passed to feature_evaluation.
    df = pd.DataFrame(
        [
            {
                "name": f"gpas_{i}",
                "emodel": emodel,
                "new_parameters": json.dumps({"g_pas.all": val}),
            }
            for i, val in enumerate(values)
        ]
    )

    print(df)
    # Reuse caches the evaluation result in the CSV file named below.
    with Reuse(f"gpas_scan_eval_{name}.csv") as reuse:
        df = reuse(
            feature_evaluation,
            df,
            access_point,
            parallel_factory=parallel_factory,
            trace_data_path="traces",
            # record_ions_and_currents=True,
        )

    d = get_feature_df(df)
    print(d)

    # Keep features whose protocol name (text before the first ".") has
    # exactly four "_"-separated parts.
    # NOTE(review): presumably this selects the holding/step scan protocols;
    # confirm against config/recipes.json.
    columns = [c for c in d.columns if len(c.split(".")[0].split("_")) == 4]
    # Pair each "<prefix>.voltage_base" column with "<prefix>.all_burst_number".
    # This does not depend on the model, so compute it once.
    column_pairs = [
        (col, col.rsplit(".", 1)[0] + ".all_burst_number")
        for col in columns
        if col.endswith(".voltage_base")
    ]

    colors = plt.get_cmap("coolwarm")(np.linspace(0, 1, len(d)))
    fig, ax = plt.subplots(figsize=(5, 3))
    for gid, color in zip(d.index, colors):
        _df = pd.DataFrame(
            {
                "burst_number": [d.loc[gid, burst] for _, burst in column_pairs],
                "holding_voltage": [d.loc[gid, vbase] for vbase, _ in column_pairs],
            }
        )
        print(_df)
        if _df.empty:
            # No matching voltage_base feature: nothing to plot, and
            # DataFrame.plot would fail on the missing points.
            continue
        if gid == ref_index:
            # NOTE(review): likely overridden because the midpoint of
            # "coolwarm" is close to white and hard to see; confirm intent.
            color = "C2"
        # NOTE(review): assumes d.index holds the positional row numbers
        # 0..len(values)-1 of the scan, so it can index `values`.
        _df.plot(
            x="holding_voltage",
            y="burst_number",
            ax=ax,
            marker="+",
            label=f"gpas={values[gid]}",
            c=color,
        )
    ax.set_xlabel("holding_voltage")
    ax.set_ylabel("burst_number")
    ax.legend()
    ax.axis([-80, -50, 0, 37])
    fig.tight_layout()
    fig.savefig(f"gpas_burst_scan_{name}.pdf")
    plt.close(fig)
